%% Use OptDMD method to extract coherent structures from the secondary flow
% Sijie Sun et al., PNAS, 2023
% Based on Travis Askham 2017, OptDMD method
%% Choose Dataset
close all;clc;clear
% The .mat file must provide U, V (frames x nx x ny velocity components,
% per the size() calls below) and the sample times t.
load('Velocity_Field.mat')
addpath('./src'); % presumably provides optdmd used further down
%% Setup
nt=size(U,1); % number of frames
nx=size(U,2); % x dimension
ny=size(U,3); % y dimension
dt=t(2)-t(1); % assumes uniformly sampled t
mm=1e-3;      % 1 mm = 1e-3 m
%% Preview Velocity Field
% Time-averaged field over the first 500 frames (or all frames if fewer),
% so short recordings do not index past the end of U/V.
nPreview=min(500,nt);
figure();hold on
quiver(squeeze(mean(U(1:nPreview,:,:),1)),squeeze(mean(V(1:nPreview,:,:),1)),2,'LineWidth',1.5)
axis equal
xlabel('x')
ylabel('y')
%% Merge 2D Velocity field into one vector
% Column k of u is frame k with components interleaved as
% [Ux1;Uy1;Ux2;Uy2;...] (see UV2X).
u=zeros(nx*ny*2,nt);
for i=1:nt
    u(:,i)=UV2X(reshape(U(i,:,:),[],1),reshape(V(i,:,:),[],1));
end
%% SVD decomposition with truncation for reference
% The SVD factors get their own names so the raw velocity arrays U and V
% loaded from Velocity_Field.mat are not overwritten; otherwise re-running
% an earlier cell would silently operate on SVD factors.
[Usvd,Ssvd,Vsvd]=svd(u,'econ');
figure()
stem(diag(Ssvd).^2); % squared singular values (energy of each SVD mode)
set(gca,'YScale','log')
xlim([0,10])         % only the leading ten modes are shown
% Rank-r truncation (clamped to the number of available singular values)
r=min(21,size(Ssvd,1));
Usvd=Usvd(:,1:r);
Ssvd=Ssvd(1:r,1:r);
Vsvd=Vsvd(:,1:r);
uhat=Usvd*Ssvd*Vsvd';
figure()
hold on
% Note: mean over all rows of u averages the interleaved Ux and Uy entries.
plot(t,mean(u,1),'r-','DisplayName','Ground Truth')
plot(t,mean(uhat,1),'b--','linewidth',2,'DisplayName','SVD with truncation')
xlim([0,60])
xlabel('Time-s')
ylabel('Speed-m/s')

legend('Location','best')
set(gca,'fontsize',16)

%% Optimized DMD, On truncated modes
% optdmd (from ./src, after Askham & Kutz) is assumed to fit
%   u(:,k) ~ Phi*diag(b)*exp(eigs*t(k)),
% i.e. the model is parameterised by the sample times t passed to it.
% The 4th argument (2) selects optdmd's imode; see optdmd.m for its meaning.
% NOTE(review): OPT only sets maxiter -- confirm the ./src optdmd accepts a
% partial options struct.
OPT.maxiter=10000;
tRow=reshape(t,1,[]); % sample times as a row: one column per frame

% Low-rank fit (r = 3); plotted as "Two structures" in the comparison cell.
r=3;
[Phi_optDMD,eigs_optDMD,b_optDMD]=optdmd(u,t,r,2,OPT);
% Evaluate the model at the true sample times t(k), not k*dt: k*dt is one
% frame ahead when t(1)=0 and misaligned whenever t(1)~=0. Averaging over
% grid points is linear, so average the Ux rows (odd rows) of Phi first
% instead of materialising the full complex reconstruction.
uxhat1=real(mean(Phi_optDMD(1:2:end,:),1)*(b_optDMD(:).*exp(eigs_optDMD(:)*tRow)));

% Full fit (r = 21); Phi_optDMD/eigs_optDMD/b_optDMD are reused by later cells.
r=21;
[Phi_optDMD,eigs_optDMD,b_optDMD]=optdmd(u,t,r,2,OPT);
uxhat2=real(mean(Phi_optDMD(1:2:end,:),1)*(b_optDMD(:).*exp(eigs_optDMD(:)*tRow)));
% Measured frame-averaged Ux (odd rows of u).
ux_true=mean(u(1:2:end,:),1);
%% Compare decomposed flow with the measured flow
% Overlay the frame-averaged Ux (converted to mm/s) of the measurement and
% of the two OptDMD reconstructions, each drawn half-transparent, and save
% the figure as a 500-dpi PNG.
f=figure('PaperPositionMode',"auto",'Units',"inches",'Position',[1,1,3,2.5]);
hold on
a=plot(t,ux_true/mm,'DisplayName','Measured value','linewidth',2,'Color','#377eb8');
b=plot(t,uxhat1/mm,'DisplayName','Two structures','linewidth',2,'Color','#e41a1c');
c=plot(t,uxhat2/mm,'DisplayName','Reconstructed value','linewidth',2,'Color','#4daf4a');
% A 4th Color component acts as the line's alpha.
for h=[a,b,c]
    h.Color=[h.Color,0.5];
end
hold off

legend('Location','best')
legend('boxoff')
set(gca,'FontSize',10)
xlabel('Time-s')
ylabel('Frame Averaged Ux-mm/s')
xlim([0,20])
ylim([-5,3])
box on
set(gca,'LineWidth',1)
print(f,'DMD reconstruction ','-dpng','-r500')
%% Check L2 Difference
% Relative error of the current (r = 21) fit: the sum over frames of
% ||u(:,k) - uhat(:,k)||_2 divided by the sum over frames of ||u(:,k)||_2.
% uhat is evaluated at the sample times t(k) that optdmd was fitted on, and
% only its real part is compared because the measured data are real.
L2difference=0;
L2true=0;
tSample=reshape(t,1,[]);
for j=1:nt
    reconstructed=real(Phi_optDMD*(b_optDMD(:).*exp(eigs_optDMD(:)*tSample(j))));
    L2difference=L2difference+norm(u(:,j)-reconstructed);
    L2true=L2true+norm(u(:,j));
end
disp(['relative L2 difference is ',num2str(L2difference/L2true),', L2 original is ',num2str(L2true)])

%% Plot DMD spectrum
% Discrete-time eigenvalues exp(lambda*dt) drawn against the unit circle
% (|mu| < 1: decaying, |mu| = 1: neutral), saved as a 500-dpi PNG.
f2=figure();
hold on
theta=(0:100)*2*pi/100;          % 101 samples, closing the circle
plot(cos(theta),sin(theta),'k--')
mu=exp(eigs_optDMD*dt);
scatter(real(mu),imag(mu),'ok')
axis([-1.1 1.1 -1.1 1.1]);
xlabel('Re of Eigen Value')
ylabel('Im of Eigen Value')
title('DMD Eigen Value Distribution')
axis equal
print(f2,'Spectrum of Modes','-dpng','-r500')

%% Plot DMD Mode Importance Distribution
% Rank modes by amplitude |b|; od1 (the ranking) is reused to pick the
% modes plotted in the eigenmode cell.
f3=figure();
[Y,od1]=sort(abs(b_optDMD(:)),'descend');
nModeBars=min(11,numel(Y));      % at most the 11 largest modes
bar(Y(1:nModeBars)/sum(Y),'linewidth',0.0001)
ylim([5e-3,1])
set(gca,'yscale','log')
box on; grid on
xlabel('Mode Number')
ylabel('Contribution')
title('Mode Contribution Distribution')
%% Plot Flow structure Importance Distribution
% Group sorted amplitudes into structures: the leading mode on its own
% (paired with a padding 0), then consecutive pairs. This assumes the
% leading mode is real and the others are complex-conjugate pairs with equal
% |b| that end up adjacent after sorting -- TODO confirm for new datasets.
Y1=[0;Y];
nStruct=min(11,floor(numel(Y1)/2));
% Column i of the 2-by-nStruct reshape is Y1(2i-1:2i); sum each pair into
% an nStruct-by-1 column.
Y2=sum(reshape(Y1(1:2*nStruct),2,nStruct),1).';
f4=figure();
bar(Y2/sum(Y),'linewidth',0.0001)
ylim([1e-3,1])
set(gca,'yscale','log')
box on; grid off
xlabel('Structure Number')
ylabel('Contribution')
title('Structure contribution Distribution')
set(gca,'fontsize',14)
set(gca,'linewidth',2)
print(f4,'Structure Contribution Distribution','-dpng','-r500')

%% Plot the eigen modes
% Plot the (up to) five largest-|b| modes, ranked by od1, and pass the rank
% so each saved PNG gets a distinct, meaningful file name.
Scale=3;
for j=1:min(5,numel(od1))
    i=od1(j);
    [f,ax1,ax2]=plotVectorfield(Phi_optDMD(:,i),nx,ny,exp(eigs_optDMD(i)*dt),dt,Scale,j);
end

%% Auxiliary Functions

function X=UV2X(U,V)
% UV2X Interleave two equal-length component vectors into one column:
%   X = [U(1); V(1); U(2); V(2); ...]
    X=zeros(length(U)*2,1);
    X(1:2:end-1)=U;
    X(2:2:end)=V;
end

function [f,ax1,ax2]=plotVectorfield(X,nx,ny,eig_val,dt,Scale,modeIdx)
% plotVectorfield Plot the real and imaginary parts of one interleaved mode
% vector as side-by-side quiver plots and save the figure as a 300-dpi PNG.
%   X        interleaved vector [Ux1;Uy1;Ux2;Uy2;...] of length 2*nx*ny
%   nx, ny   sizes used to reshape each component to an nx-by-ny grid
%   eig_val  eigenvalue shown in the title; its angle gives the period
%   dt       sampling interval used to turn that angle into a period
%   Scale    quiver scale factor
%   modeIdx  (optional) number used in the output file name. Defaults to the
%            figure number, so repeated calls never overwrite each other.
%            (Previously the name used an undefined j, i.e. the imaginary
%            unit, so every call wrote the same "...Mode 0+1i" file.)
%   Returns the figure handle and the two subplot axes.
    Uxr=reshape(real(X(1:2:end-1)),nx,ny);
    Uyr=reshape(real(X(2:2:end)),nx,ny);
    Uxi=reshape(imag(X(1:2:end-1)),nx,ny);
    Uyi=reshape(imag(X(2:2:end)),nx,ny);
    % Each part is normalised by its own norm. A purely real (or purely
    % imaginary) mode has a zero norm for the other part; skip the scaling
    % there instead of producing 0/0 = NaN arrows.
    Lr=norm(real(X(:)));
    Li=norm(imag(X(:)));
    if Lr==0, Lr=1; end
    if Li==0, Li=1; end
    f=figure('Position',[100,100,1600,700]);
    if nargin<7 || isempty(modeIdx)
        modeIdx=f.Number;
    end
    ax1=subplot(1,2,1);
        quiver(Uxr/Lr,Uyr/Lr,Scale,'LineWidth',1.5)
        axis equal
        title('Real Part')
        xlabel('X - No Unit')
        ylabel('Y - No Unit')
        set(gca,'linewidth',2)
        set(gca,'fontsize',16)
    ax2=subplot(1,2,2);
        quiver(Uxi/Li,Uyi/Li,Scale,'LineWidth',1.5)
        axis equal
        xlabel('X - No Unit')
        ylabel('Y - No Unit')
        set(gca,'linewidth',2)
        set(gca,'fontsize',16)
        title('Imaginary Part')
    % Phase advance per sample -> angular frequency -> period.
    % A positive real eigenvalue has zero angle, which gives tau = Inf.
    omega=angle(eig_val)/dt;
    tau=2*pi/omega;
    sgtitle(['% \lambda = ',num2str(eig_val) ,' \tau =' ,num2str(abs(tau)),' s'],'fontsize',20)
    print(f,['Velocity field of Sorted Mode ' num2str(modeIdx)],'-dpng','-r300')
end


